;;; guile-openai --- An OpenAI API client for Guile
;;; Copyright © 2023 Andrew Whatson <whatson@tailcall.au>
;;;
;;; This file is part of guile-openai.
;;;
;;; guile-openai is free software: you can redistribute it and/or modify
;;; it under the terms of the GNU Affero General Public License as
;;; published by the Free Software Foundation, either version 3 of the
;;; License, or (at your option) any later version.
;;;
;;; guile-openai is distributed in the hope that it will be useful, but
;;; WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;;; Affero General Public License for more details.
;;;
;;; You should have received a copy of the GNU Affero General Public
;;; License along with guile-openai.  If not, see
;;; <https://www.gnu.org/licenses/>.

(define-module (openai chat)
  #:use-module (openai api chat)
  #:use-module (openai client)
  #:use-module (openai utils colorized)
  #:use-module (openai utils stream)
  #:use-module (ice-9 match)
  #:use-module (json parser)
  #:use-module (srfi srfi-9)
  #:use-module (srfi srfi-9 gnu)
  #:use-module (srfi srfi-41)
  #:export (openai-default-chat-model
            openai-default-chat-temperature
            openai-default-chat-top-p

            chat?
            chat-content
            chat-stream

            call?
            call-function
            call-arguments
            call-stream

            openai-chat))

;; Default model for `openai-chat' when #:model is omitted.  `define-once'
;; keeps any existing binding if this module is reloaded.
(define-once openai-default-chat-model (make-parameter (quote gpt-3.5-turbo)))

;; Sampling defaults.  *unspecified* is the same "not supplied" marker the
;; other optional request fields of `openai-chat' use (presumably omitted
;; from the request by the API layer -- confirm in (openai api chat)).
(define-once openai-default-chat-temperature (make-parameter *unspecified*))
(define-once openai-default-chat-top-p       (make-parameter *unspecified*))

;; Default for #:stream? in `openai-chat'.  NOTE(review): unlike the
;; parameters above, this one is not listed in the module's #:export.
(define-once openai-default-chat-stream? (make-parameter #t))

;; A plain-text chat response.  CONTENT is a promise yielding the complete
;; text; STREAM holds the text pieces (a one-element stream for
;; non-streamed responses).
(define-record-type <Chat>
  (%make-chat content stream)
  chat?
  (content %chat-content)
  (stream  chat-stream))

(define chat-content
  (lambda (chat)
    "Return the complete text of CHAT, forcing its content promise."
    (let ((promise (%chat-content chat)))
      (force promise))))

;; A function-call response.  FUNCTION is the called function's name,
;; ARGUMENTS is a promise yielding the call's arguments, and STREAM holds
;; the raw argument text pieces.
(define-record-type <Call>
  (%make-call function arguments stream)
  call?
  (function  call-function)
  (arguments %call-arguments)
  (stream    call-stream))

(define call-arguments
  (lambda (call)
    "Return the arguments of CALL, forcing its arguments promise."
    (let ((promise (%call-arguments call)))
      (force promise))))

(define (make-chat-or-call result ix)
  "Build the chat or call record for choice IX of RESULT.

RESULT is either a stream of response chunks (streamed request) or a
single complete response.  A message carrying a function call becomes a
call record; any other message becomes a chat record.

In both modes a call record's arguments are decoded lazily with
`json-string->scm', so `call-arguments' returns the same kind of value
whether or not the request was streamed."
  (if (stream? result)
      ;; handle streamed responses
      (let* ((deltas (responses->delta-stream result ix))
             ;; the first delta decides between a chat and a function call
             (message (stream-car deltas))
             (call (chat-message-function-call message)))
        (if (unspecified? call)
            ;; make a streamed chat response
            (let* ((strm (deltas->content-stream deltas))
                   (content (delay (string-concatenate (stream->list strm)))))
              (%make-chat content strm))
            ;; make a streamed function call response
            (let* ((func (chat-function-call-name call))
                   (strm (deltas->argument-stream deltas))
                   (args (delay (json-string->scm
                                 (string-concatenate (stream->list strm))))))
              (%make-call func args strm))))
      ;; handle complete responses
      (let* ((message (response->message result ix))
             (call (chat-message-function-call message)))
        (if (unspecified? call)
            ;; make a complete chat response
            (let ((content (chat-message-content message)))
              (%make-chat (delay content) (stream content)))
            ;; make a complete function call response.  ARGS is presumably
            ;; the raw JSON text (the streamed branch above concatenates the
            ;; same field as strings), so decode it the same way; the stream
            ;; keeps the raw text for printing.
            (let* ((func (chat-function-call-name call))
                   (args (chat-function-call-arguments call)))
              (%make-call func
                          (delay (json-string->scm args))
                          (stream args)))))))

(define (make-chats-or-calls result n)
  "Return N values: the chat or call record for each choice index 0 .. N-1
of RESULT, in index order."
  (define (record-at ix)
    (make-chat-or-call result ix))
  (let ((records (map record-at (iota n))))
    (apply values records)))

(define (response->message response n)
  "Return the message of the Nth choice (by list position) of the complete
RESPONSE."
  (let* ((choices (chat-response-choices response))
         (choice  (list-ref choices n)))
    (chat-choice-message choice)))

(define (responses->delta-stream responses n)
  "Return a stream of the message deltas belonging to choice N.

RESPONSES is the stream of response chunks of a streamed request.  Every
choice of each chunk is inspected, and the first one whose index is N
contributes its delta.  Chunks with no such choice are skipped, including
chunks whose choices list is empty."
  (define (delta-for-index choices)
    ;; Some servers may send chunks with no choices at all (e.g. filter or
    ;; usage-only chunks); taking `car' of those would signal an error.
    (match choices
      ((choice . rest)
       (if (eqv? n (chat-choice-index choice))
           (chat-choice-delta choice)
           (delta-for-index rest)))
      (_ #f)))
  (stream-filter-map
   (lambda (response)
     (delta-for-index (chat-response-choices response)))
   responses))

(define (deltas->argument-stream deltas)
  "Return a stream of the function-call argument text pieces in DELTAS.
Deltas without a function call, or whose arguments are not a string,
contribute nothing."
  (define (argument-chunk delta)
    (let ((call (chat-message-function-call delta)))
      (if (unspecified? call)
          #f
          (let ((args (chat-function-call-arguments call)))
            (if (string? args) args #f)))))
  (stream-filter-map argument-chunk deltas))

(define (deltas->content-stream deltas)
  "Return a stream of the text content pieces in DELTAS, skipping deltas
whose content is not a string."
  (define (content-chunk delta)
    (let ((text (chat-message-content delta)))
      (if (string? text) text #f)))
  (stream-filter-map content-chunk deltas))

(define (print-chat chat port)
  "Record printer for <Chat>: write a newline, then every text piece of
CHAT's stream to PORT."
  (let ((pieces (chat-stream chat)))
    (newline port)
    (stream-for-each (lambda (piece) (display piece port))
                     pieces)))

(define (print-call call port)
  "Record printer for <Call>: write the function name, then every raw
argument text piece of CALL's stream to PORT."
  (let ((name   (call-function call))
        (pieces (call-stream call)))
    (newline port)
    (format port "function: ~a\n" name)
    (display "arguments: " port)
    (stream-for-each (lambda (piece) (display piece port))
                     pieces)))

;; Display chat and call records by writing out their streamed contents.
(set-record-type-printer! <Chat> print-chat)
(set-record-type-printer! <Call> print-call)

;; Colorized REPL output for both record kinds.
;; NOTE(review): call records are registered under the CHAT tag as well,
;; which looks like a copy/paste of the chat entry (CALL intended?) --
;; confirm how (openai utils colorized) uses the tag before changing it.
(add-color-scheme! (list (list chat? 'CHAT color-stream '(GREEN BOLD))))
(add-color-scheme! (list (list call? 'CHAT color-stream '(GREEN BOLD))))

;; Normalize a prompt into a list of chat messages.  A prompt may be the
;; empty list, a single message (a string, or a `(role . content)' pair),
;; or a list of messages.
(define parse-prompt
  (match-lambda
    ((? null?)
     '())
    ;; A single `(role . content)' message.  It must be tested before the
    ;; generic list case: such a pair is itself `pair?', and mapping over
    ;; it as a list would fail on the improper tail.
    ((and msg ((? symbol?) . (? string?)))
     (list (parse-message msg)))
    ((? pair? msgs)
     (map parse-message msgs))
    (msg
     (list (parse-message msg)))))

(define (parse-message msg)
  "Convert MSG into a chat message.  A bare string becomes a \"user\"
message; a `(role . content)' pair with role `system', `user' or
`assistant' and string content keeps its role.  Anything else signals a
match error."
  (match msg
    ((? string? text)
     (make-chat-message "user" text))
    (((and role (or 'system 'user 'assistant)) . (? string? text))
     (make-chat-message (symbol->string role) text))))

(define* (openai-chat prompt #:key
                      (model             (openai-default-chat-model))
                      (functions         *unspecified*)
                      (function-call     *unspecified*)
                      (temperature       (openai-default-chat-temperature))
                      (top-p             (openai-default-chat-top-p))
                      (n                 *unspecified*)
                      (stream?           (openai-default-chat-stream?))
                      (stop              *unspecified*)
                      (max-tokens        *unspecified*)
                      (presence-penalty  *unspecified*)
                      (frequency-penalty *unspecified*)
                      (logit-bias        *unspecified*)
                      (user              (openai-default-user)))
  "Send a chat completion request.

Returns one record per requested choice, as multiple values (a single
value unless #:n is given).  Each record is a chat record (see
`chat-content' and `chat-stream') or, when the model answered with a
function call, a call record (see `call-function', `call-arguments' and
`call-stream').

The PROMPT can be a string, which will be sent as a user message.
Alternatively, PROMPT can be a list of `(role . content)' pairs, where
content is a string and role is one of the symbols `system', `user', or
`assistant'.

The keyword arguments correspond to the request parameters described
in the chat completion request documentation:

#:n - The number of responses to generate, returned as multiple
values.

#:stream? - Whether to stream the response(s).  Defaults to the value
of `openai-default-chat-stream?', which is initially `#t'.

#:model - A symbol or string identifying the model to use.  Defaults
to `gpt-3.5-turbo', but if you're lucky you might be able to use
`gpt-4' here.

#:functions, #:function-call - Passed through unchanged to the request.

#:temperature - The sampling temperature to use, a number between 0
and 2.

#:top-p - An alternative sampling parameter, a number between 0 and 1.

#:user - An optional username to associate with this request.

The `#:stop', `#:max-tokens', `#:logit-bias', `#:presence-penalty',
`#:frequency-penalty' parameters are implemented but untested."
  (let* ((model-name (if (symbol? model) (symbol->string model) model))
         (messages (parse-prompt prompt))
         ;; a false #:stream? becomes *unspecified*, the same "not supplied"
         ;; marker every other omitted option uses
         (stream-option (if stream? stream? *unspecified*))
         (request (make-chat-request model-name messages functions function-call
                                     temperature top-p n stream-option stop
                                     max-tokens presence-penalty
                                     frequency-penalty logit-bias user))
         (response (send-chat-request request))
         (choice-count (if (unspecified? n) 1 n)))
    (make-chats-or-calls response choice-count)))
